Installing required libraries

In [1]:
%pip install keras torch torchvision seaborn tensorflow
Requirement already satisfied: keras in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (3.12.0)
Requirement already satisfied: torch in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (2.8.0+cu128)
Requirement already satisfied: torchvision in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (0.23.0+cu128)
Requirement already satisfied: seaborn in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (0.13.2)
Requirement already satisfied: tensorflow in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (2.20.0)
[... transitive dependency lines omitted ...]
Note: you may need to restart the kernel to use updated packages.

Importing libraries

Note: Training was run on lightning.ai for faster compute.

In [2]:
import time
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from keras.datasets import mnist, fashion_mnist
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.manifold import TSNE

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms, models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
2025-11-04 10:07:27.689711: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Device: cuda
In [3]:
import torch
print("CUDA available:", torch.cuda.is_available())
print("Device name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU only")
CUDA available: True
Device name: NVIDIA L40S

Helper functions

In [4]:
# Fix random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

def show_confusion(cm, labels, title='Confusion Matrix'):
    # Plot a confusion matrix as an annotated seaborn heatmap.
    plt.figure(figsize=(7,6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels)
    plt.title(title)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.show()

Importing the dataset and reshaping it

In [ ]:
# Load MNIST, flatten each 28x28 image into a 784-dimensional vector,
# and convert to float tensors scaled to [0, 1] with integer label tensors.
(train_X, train_y), (test_X, test_y) = mnist.load_data()

train_X = torch.from_numpy(train_X).float().reshape(-1, 784) / 255
train_y = torch.from_numpy(train_y).to(torch.int64)

test_X = torch.from_numpy(test_X).float().reshape(-1, 784) / 255
test_y = torch.from_numpy(test_y).to(torch.int64)
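
A quick sanity check on shapes, dtypes, and value ranges catches flattening or scaling mistakes early. A minimal sketch (the expected sizes follow from mnist.load_data and the reshape above):

In [ ]:
# Expect 60,000 training rows and 10,000 test rows of 784 pixels each, scaled to [0, 1].
assert train_X.shape == (60000, 784) and test_X.shape == (10000, 784)
assert train_y.dtype == torch.int64 and test_y.dtype == torch.int64
print(train_X.min().item(), train_X.max().item())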

Evaluation and visualization helpers

In [6]:
def evaluate_model(model, X_test, y_test, device='cpu'):
    # Evaluate in a single full-batch forward pass; returns predictions,
    # accuracy, macro F1, confusion matrix, and cross-entropy loss.
    model.eval()
    with torch.no_grad():
        X_test = X_test.to(device)
        y_test = y_test.to(device)
        outputs = model(X_test)
        loss = F.cross_entropy(outputs, y_test).item()
        preds = outputs.argmax(dim=1).cpu().numpy()
        y_true = y_test.cpu().numpy()
    acc = accuracy_score(y_true, preds)
    f1 = f1_score(y_true, preds, average='macro')
    cm = confusion_matrix(y_true, preds)
    return preds, acc, f1, cm, loss
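
evaluate_model pushes the entire test set through the network in one forward pass, which is fine for MNIST's 10,000 x 784 test tensor but can exhaust GPU memory on larger datasets. A sketch of a batched variant (the batch size of 1024 is an assumption, and it assumes X_test and y_test live on the CPU as above):

In [ ]:
def evaluate_model_batched(model, X_test, y_test, device='cpu', batch_size=1024):
    # Same metrics as evaluate_model, but streamed in batches to bound GPU memory.
    model.eval()
    all_preds, total_loss = [], 0.0
    with torch.no_grad():
        for i in range(0, len(X_test), batch_size):
            xb = X_test[i:i+batch_size].to(device)
            yb = y_test[i:i+batch_size].to(device)
            out = model(xb)
            total_loss += F.cross_entropy(out, yb, reduction='sum').item()
            all_preds.append(out.argmax(dim=1).cpu())
    preds = torch.cat(all_preds).numpy()
    y_true = y_test.numpy()
    acc = accuracy_score(y_true, preds)
    f1 = f1_score(y_true, preds, average='macro')
    cm = confusion_matrix(y_true, preds)
    return preds, acc, f1, cm, total_loss / len(X_test)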
In [7]:
def summary(name, acc, f1, cm, train_losses=None):
    # Print metrics, show the confusion matrix, and optionally plot the loss curve.
    print(f"\n{name}")
    print(f"Accuracy: {acc:.4f}, F1-score: {f1:.4f}")
    show_confusion(cm, list(range(10)), title=f"{name} Confusion Matrix")

    if train_losses is not None:
        plt.figure(figsize=(7, 5))
        plt.plot(train_losses, label="Training Loss", linewidth=2)
        plt.xlabel("Epochs")
        plt.ylabel("Loss")
        plt.title(f"Training Loss vs Epochs - {name}")
        plt.legend()
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()
In [8]:
def visualize_tsne(model, X, y, trained=True, device='cpu'):
    model.eval()
    X, y = X.to(device), y.cpu().numpy()
    with torch.no_grad():
        # Mirror MLP_relu's forward pass up to the 20-neuron second hidden layer;
        # note the hard-coded ReLU would not match MLP_sigmoid's own activations.
        x = F.relu(model.fc1(X))
        layer2_out = model.fc2(x).cpu().numpy()

    tsne = TSNE(n_components=2, random_state=42)
    tsne_results = tsne.fit_transform(layer2_out)

    plt.figure(figsize=(10, 6))
    num_classes = len(np.unique(y))
    for i in range(num_classes):
        indices = (y == i)
        plt.scatter(
            tsne_results[indices, 0],
            tsne_results[indices, 1],
            label=i,
            alpha=0.5
        )

    plt.legend(title="Class")
    plt.title(f"t-SNE (20-neuron layer) - {'Trained' if trained else 'Untrained'} Model")
    plt.tight_layout()
    plt.show()
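
Fitting t-SNE on all 60,000 training embeddings is slow, so it is usually worth subsampling before calling visualize_tsne. A minimal sketch (the 2,000-point sample is an assumed size, not a value from this notebook):

In [ ]:
# Subsample a fixed number of points; ~2,000 is typically enough to see cluster structure.
idx = torch.randperm(len(test_X))[:2000]
sample_X, sample_y = test_X[idx], test_y[idx]
# visualize_tsne(model, sample_X, sample_y, trained=True, device=device)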

Defining the MLP

In [10]:
class MLP_relu(nn.Module):
    # Three-layer MLP (784 -> 30 -> 20 -> 10) with ReLU activations; outputs raw logits.
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 30)
        self.fc2 = nn.Linear(30, 20)
        self.fc3 = nn.Linear(20, 10)

    def forward(self, x):
        x = x.view(len(x), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
In [11]:
class MLP_sigmoid(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 30)
        self.fc2 = nn.Linear(30, 20)
        self.fc3 = nn.Linear(20, 10)

    def forward(self, x):
        x = x.view(len(x), -1)
        x = torch.sigmoid(self.fc1(x))  # F.sigmoid is deprecated; torch.sigmoid is the current API
        x = torch.sigmoid(self.fc2(x))
        return self.fc3(x)
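
MLP_relu and MLP_sigmoid differ only in their activation function. As a refactoring sketch (not what this notebook uses), a single class taking the activation as a parameter would cover both:

In [ ]:
class MLP(nn.Module):
    # Same 784 -> 30 -> 20 -> 10 network, with the activation passed in.
    def __init__(self, act=F.relu):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 30)
        self.fc2 = nn.Linear(30, 20)
        self.fc3 = nn.Linear(20, 10)
        self.act = act

    def forward(self, x):
        x = x.view(len(x), -1)
        x = self.act(self.fc1(x))
        x = self.act(self.fc2(x))
        return self.fc3(x)  # raw logits; CrossEntropyLoss applies log-softmax

Instances would then be created as MLP() for ReLU or MLP(act=torch.sigmoid) for sigmoid.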
In [12]:
def train_mlp(model, X_train, y_train, epochs=100, lr=0.001):
    # Full-batch training: a single Adam step per epoch over the entire training set.
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    train_losses = []
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        outputs = model(X_train)
        loss = criterion(outputs, y_train)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")
    return train_losses
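
Because train_mlp takes exactly one optimizer step per epoch on the full 60,000-image batch, the loss curves below decrease very smoothly and 1000 epochs are needed. A mini-batch variant would typically reach a similar loss in far fewer epochs; a rough sketch using the already-imported DataLoader (the batch size of 128 and 10 epochs are assumptions, not values from this notebook):

In [ ]:
from torch.utils.data import TensorDataset

def train_mlp_minibatch(model, X_train, y_train, epochs=10, lr=0.001, batch_size=128):
    # One Adam step per mini-batch; X_train/y_train should be on the same device as the model.
    loader = DataLoader(TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=True)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    train_losses = []
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        for xb, yb in loader:
            optimizer.zero_grad()
            loss = criterion(model(xb), yb)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item() * len(xb)
        train_losses.append(epoch_loss / len(X_train))
        print(f"Epoch {epoch+1}/{epochs}, Loss: {train_losses[-1]:.4f}")
    return train_losses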

Using Cross Entropy Loss and ReLU

In [13]:
mlp_relu = MLP_relu().to(device)
print("\nTraining MLP...")
train_losses = train_mlp(mlp_relu, train_X.to(device), train_y.to(device), epochs=1000)

preds, acc, f1, cm, test_loss = evaluate_model(mlp_relu, test_X, test_y, device)
summary("MLP on MNIST with ReLU", acc, f1, cm, train_losses)
print(f"Test Loss: {test_loss:.4f}")
Training MLP...
Epoch 1/1000, Loss: 2.3110
Epoch 2/1000, Loss: 2.2992
Epoch 3/1000, Loss: 2.2862
[... epochs 4-998 omitted ...]
Epoch 999/1000, Loss: 0.0666
Epoch 1000/1000, Loss: 0.0665

MLP on MNIST with ReLU
Accuracy: 0.9621, F1-score: 0.9618
[Figure: confusion matrix for MLP on MNIST with ReLU]
[Figure: training loss vs epochs for MLP on MNIST with ReLU]
Test Loss: 0.1322
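
The test loss (0.1322) is about twice the final training loss (0.0665), which hints at mild overfitting after 1000 full-batch epochs, though test accuracy remains high at 96.2%.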

Using Cross Entropy Loss and Sigmoid

In [14]:
mlp_sigmoid = MLP_sigmoid().to(device)
print("\nTraining MLP...")
train_losses = train_mlp(mlp_sigmoid, train_X.to(device), train_y.to(device), epochs=1000)

preds, acc, f1, cm, test_loss = evaluate_model(mlp_sigmoid, test_X, test_y, device)
summary("MLP on MNIST with Sigmoid", acc, f1, cm, train_losses)
print(f"Test Loss: {test_loss:.4f}")
Training MLP...
Epoch 1/1000, Loss: 2.3312
Epoch 2/1000, Loss: 2.3266
Epoch 3/1000, Loss: 2.3222
[... epochs 4-505 omitted ...]
Epoch 506/1000, Loss: 0.6141
Epoch 507/1000, Loss: 0.6127
Epoch 508/1000, Loss: 0.6112
Epoch 509/1000, Loss: 0.6098
Epoch 510/1000, Loss: 0.6084
Epoch 511/1000, Loss: 0.6070
Epoch 512/1000, Loss: 0.6056
Epoch 513/1000, Loss: 0.6042
Epoch 514/1000, Loss: 0.6028
Epoch 515/1000, Loss: 0.6014
Epoch 516/1000, Loss: 0.6000
Epoch 517/1000, Loss: 0.5987
Epoch 518/1000, Loss: 0.5973
Epoch 519/1000, Loss: 0.5959
Epoch 520/1000, Loss: 0.5945
Epoch 521/1000, Loss: 0.5932
Epoch 522/1000, Loss: 0.5918
Epoch 523/1000, Loss: 0.5905
Epoch 524/1000, Loss: 0.5891
Epoch 525/1000, Loss: 0.5878
Epoch 526/1000, Loss: 0.5864
Epoch 527/1000, Loss: 0.5851
Epoch 528/1000, Loss: 0.5838
Epoch 529/1000, Loss: 0.5825
Epoch 530/1000, Loss: 0.5811
Epoch 531/1000, Loss: 0.5798
Epoch 532/1000, Loss: 0.5785
Epoch 533/1000, Loss: 0.5772
Epoch 534/1000, Loss: 0.5759
Epoch 535/1000, Loss: 0.5746
Epoch 536/1000, Loss: 0.5733
Epoch 537/1000, Loss: 0.5720
Epoch 538/1000, Loss: 0.5707
Epoch 539/1000, Loss: 0.5694
Epoch 540/1000, Loss: 0.5681
Epoch 541/1000, Loss: 0.5669
Epoch 542/1000, Loss: 0.5656
Epoch 543/1000, Loss: 0.5643
Epoch 544/1000, Loss: 0.5631
Epoch 545/1000, Loss: 0.5618
Epoch 546/1000, Loss: 0.5605
Epoch 547/1000, Loss: 0.5593
Epoch 548/1000, Loss: 0.5580
Epoch 549/1000, Loss: 0.5568
Epoch 550/1000, Loss: 0.5556
Epoch 551/1000, Loss: 0.5543
Epoch 552/1000, Loss: 0.5531
Epoch 553/1000, Loss: 0.5519
Epoch 554/1000, Loss: 0.5506
Epoch 555/1000, Loss: 0.5494
Epoch 556/1000, Loss: 0.5482
Epoch 557/1000, Loss: 0.5470
Epoch 558/1000, Loss: 0.5458
Epoch 559/1000, Loss: 0.5446
Epoch 560/1000, Loss: 0.5434
Epoch 561/1000, Loss: 0.5422
Epoch 562/1000, Loss: 0.5410
Epoch 563/1000, Loss: 0.5398
Epoch 564/1000, Loss: 0.5386
Epoch 565/1000, Loss: 0.5374
Epoch 566/1000, Loss: 0.5363
Epoch 567/1000, Loss: 0.5351
Epoch 568/1000, Loss: 0.5339
Epoch 569/1000, Loss: 0.5327
Epoch 570/1000, Loss: 0.5316
Epoch 571/1000, Loss: 0.5304
Epoch 572/1000, Loss: 0.5293
Epoch 573/1000, Loss: 0.5281
Epoch 574/1000, Loss: 0.5270
Epoch 575/1000, Loss: 0.5258
Epoch 576/1000, Loss: 0.5247
Epoch 577/1000, Loss: 0.5235
Epoch 578/1000, Loss: 0.5224
Epoch 579/1000, Loss: 0.5213
Epoch 580/1000, Loss: 0.5201
Epoch 581/1000, Loss: 0.5190
Epoch 582/1000, Loss: 0.5179
Epoch 583/1000, Loss: 0.5168
Epoch 584/1000, Loss: 0.5157
Epoch 585/1000, Loss: 0.5146
Epoch 586/1000, Loss: 0.5134
Epoch 587/1000, Loss: 0.5123
Epoch 588/1000, Loss: 0.5112
Epoch 589/1000, Loss: 0.5101
Epoch 590/1000, Loss: 0.5091
Epoch 591/1000, Loss: 0.5080
Epoch 592/1000, Loss: 0.5069
Epoch 593/1000, Loss: 0.5058
Epoch 594/1000, Loss: 0.5047
Epoch 595/1000, Loss: 0.5036
Epoch 596/1000, Loss: 0.5026
Epoch 597/1000, Loss: 0.5015
Epoch 598/1000, Loss: 0.5004
Epoch 599/1000, Loss: 0.4994
Epoch 600/1000, Loss: 0.4983
Epoch 601/1000, Loss: 0.4973
Epoch 602/1000, Loss: 0.4962
Epoch 603/1000, Loss: 0.4952
Epoch 604/1000, Loss: 0.4941
Epoch 605/1000, Loss: 0.4931
Epoch 606/1000, Loss: 0.4920
Epoch 607/1000, Loss: 0.4910
Epoch 608/1000, Loss: 0.4900
Epoch 609/1000, Loss: 0.4890
Epoch 610/1000, Loss: 0.4879
Epoch 611/1000, Loss: 0.4869
Epoch 612/1000, Loss: 0.4859
Epoch 613/1000, Loss: 0.4849
Epoch 614/1000, Loss: 0.4839
Epoch 615/1000, Loss: 0.4829
Epoch 616/1000, Loss: 0.4819
Epoch 617/1000, Loss: 0.4809
Epoch 618/1000, Loss: 0.4799
Epoch 619/1000, Loss: 0.4789
Epoch 620/1000, Loss: 0.4779
Epoch 621/1000, Loss: 0.4769
Epoch 622/1000, Loss: 0.4759
Epoch 623/1000, Loss: 0.4749
Epoch 624/1000, Loss: 0.4740
Epoch 625/1000, Loss: 0.4730
Epoch 626/1000, Loss: 0.4720
Epoch 627/1000, Loss: 0.4710
Epoch 628/1000, Loss: 0.4701
Epoch 629/1000, Loss: 0.4691
Epoch 630/1000, Loss: 0.4682
Epoch 631/1000, Loss: 0.4672
Epoch 632/1000, Loss: 0.4662
Epoch 633/1000, Loss: 0.4653
Epoch 634/1000, Loss: 0.4644
Epoch 635/1000, Loss: 0.4634
Epoch 636/1000, Loss: 0.4625
Epoch 637/1000, Loss: 0.4615
Epoch 638/1000, Loss: 0.4606
Epoch 639/1000, Loss: 0.4597
Epoch 640/1000, Loss: 0.4587
Epoch 641/1000, Loss: 0.4578
Epoch 642/1000, Loss: 0.4569
Epoch 643/1000, Loss: 0.4560
Epoch 644/1000, Loss: 0.4551
Epoch 645/1000, Loss: 0.4542
Epoch 646/1000, Loss: 0.4533
Epoch 647/1000, Loss: 0.4523
Epoch 648/1000, Loss: 0.4514
Epoch 649/1000, Loss: 0.4505
Epoch 650/1000, Loss: 0.4496
Epoch 651/1000, Loss: 0.4488
Epoch 652/1000, Loss: 0.4479
Epoch 653/1000, Loss: 0.4470
Epoch 654/1000, Loss: 0.4461
Epoch 655/1000, Loss: 0.4452
Epoch 656/1000, Loss: 0.4443
Epoch 657/1000, Loss: 0.4435
Epoch 658/1000, Loss: 0.4426
Epoch 659/1000, Loss: 0.4417
Epoch 660/1000, Loss: 0.4409
Epoch 661/1000, Loss: 0.4400
Epoch 662/1000, Loss: 0.4391
Epoch 663/1000, Loss: 0.4383
Epoch 664/1000, Loss: 0.4374
Epoch 665/1000, Loss: 0.4366
Epoch 666/1000, Loss: 0.4357
Epoch 667/1000, Loss: 0.4349
Epoch 668/1000, Loss: 0.4340
Epoch 669/1000, Loss: 0.4332
Epoch 670/1000, Loss: 0.4323
Epoch 671/1000, Loss: 0.4315
Epoch 672/1000, Loss: 0.4307
Epoch 673/1000, Loss: 0.4298
Epoch 674/1000, Loss: 0.4290
Epoch 675/1000, Loss: 0.4282
Epoch 676/1000, Loss: 0.4274
Epoch 677/1000, Loss: 0.4266
Epoch 678/1000, Loss: 0.4257
Epoch 679/1000, Loss: 0.4249
Epoch 680/1000, Loss: 0.4241
Epoch 681/1000, Loss: 0.4233
Epoch 682/1000, Loss: 0.4225
Epoch 683/1000, Loss: 0.4217
Epoch 684/1000, Loss: 0.4209
Epoch 685/1000, Loss: 0.4201
Epoch 686/1000, Loss: 0.4193
Epoch 687/1000, Loss: 0.4185
Epoch 688/1000, Loss: 0.4177
Epoch 689/1000, Loss: 0.4170
Epoch 690/1000, Loss: 0.4162
Epoch 691/1000, Loss: 0.4154
Epoch 692/1000, Loss: 0.4146
Epoch 693/1000, Loss: 0.4138
Epoch 694/1000, Loss: 0.4131
Epoch 695/1000, Loss: 0.4123
Epoch 696/1000, Loss: 0.4115
Epoch 697/1000, Loss: 0.4108
Epoch 698/1000, Loss: 0.4100
Epoch 699/1000, Loss: 0.4092
Epoch 700/1000, Loss: 0.4085
Epoch 701/1000, Loss: 0.4077
Epoch 702/1000, Loss: 0.4070
Epoch 703/1000, Loss: 0.4062
Epoch 704/1000, Loss: 0.4055
Epoch 705/1000, Loss: 0.4047
Epoch 706/1000, Loss: 0.4040
Epoch 707/1000, Loss: 0.4033
Epoch 708/1000, Loss: 0.4025
Epoch 709/1000, Loss: 0.4018
Epoch 710/1000, Loss: 0.4011
Epoch 711/1000, Loss: 0.4003
Epoch 712/1000, Loss: 0.3996
Epoch 713/1000, Loss: 0.3989
Epoch 714/1000, Loss: 0.3982
Epoch 715/1000, Loss: 0.3974
Epoch 716/1000, Loss: 0.3967
Epoch 717/1000, Loss: 0.3960
Epoch 718/1000, Loss: 0.3953
Epoch 719/1000, Loss: 0.3946
Epoch 720/1000, Loss: 0.3939
Epoch 721/1000, Loss: 0.3932
Epoch 722/1000, Loss: 0.3925
Epoch 723/1000, Loss: 0.3918
Epoch 724/1000, Loss: 0.3911
Epoch 725/1000, Loss: 0.3904
Epoch 726/1000, Loss: 0.3897
Epoch 727/1000, Loss: 0.3890
Epoch 728/1000, Loss: 0.3883
Epoch 729/1000, Loss: 0.3876
Epoch 730/1000, Loss: 0.3869
Epoch 731/1000, Loss: 0.3862
Epoch 732/1000, Loss: 0.3856
Epoch 733/1000, Loss: 0.3849
Epoch 734/1000, Loss: 0.3842
Epoch 735/1000, Loss: 0.3835
Epoch 736/1000, Loss: 0.3829
Epoch 737/1000, Loss: 0.3822
Epoch 738/1000, Loss: 0.3815
Epoch 739/1000, Loss: 0.3809
Epoch 740/1000, Loss: 0.3802
Epoch 741/1000, Loss: 0.3795
Epoch 742/1000, Loss: 0.3789
Epoch 743/1000, Loss: 0.3782
Epoch 744/1000, Loss: 0.3776
Epoch 745/1000, Loss: 0.3769
Epoch 746/1000, Loss: 0.3763
Epoch 747/1000, Loss: 0.3756
Epoch 748/1000, Loss: 0.3750
Epoch 749/1000, Loss: 0.3743
Epoch 750/1000, Loss: 0.3737
Epoch 751/1000, Loss: 0.3730
Epoch 752/1000, Loss: 0.3724
Epoch 753/1000, Loss: 0.3718
Epoch 754/1000, Loss: 0.3711
Epoch 755/1000, Loss: 0.3705
Epoch 756/1000, Loss: 0.3699
Epoch 757/1000, Loss: 0.3692
Epoch 758/1000, Loss: 0.3686
Epoch 759/1000, Loss: 0.3680
Epoch 760/1000, Loss: 0.3674
Epoch 761/1000, Loss: 0.3667
Epoch 762/1000, Loss: 0.3661
Epoch 763/1000, Loss: 0.3655
Epoch 764/1000, Loss: 0.3649
Epoch 765/1000, Loss: 0.3643
Epoch 766/1000, Loss: 0.3637
Epoch 767/1000, Loss: 0.3631
Epoch 768/1000, Loss: 0.3624
Epoch 769/1000, Loss: 0.3618
Epoch 770/1000, Loss: 0.3612
Epoch 771/1000, Loss: 0.3606
Epoch 772/1000, Loss: 0.3600
Epoch 773/1000, Loss: 0.3594
Epoch 774/1000, Loss: 0.3588
Epoch 775/1000, Loss: 0.3582
Epoch 776/1000, Loss: 0.3576
Epoch 777/1000, Loss: 0.3571
Epoch 778/1000, Loss: 0.3565
Epoch 779/1000, Loss: 0.3559
Epoch 780/1000, Loss: 0.3553
Epoch 781/1000, Loss: 0.3547
Epoch 782/1000, Loss: 0.3541
Epoch 783/1000, Loss: 0.3535
Epoch 784/1000, Loss: 0.3530
Epoch 785/1000, Loss: 0.3524
Epoch 786/1000, Loss: 0.3518
Epoch 787/1000, Loss: 0.3512
Epoch 788/1000, Loss: 0.3507
Epoch 789/1000, Loss: 0.3501
Epoch 790/1000, Loss: 0.3495
Epoch 791/1000, Loss: 0.3490
Epoch 792/1000, Loss: 0.3484
Epoch 793/1000, Loss: 0.3478
Epoch 794/1000, Loss: 0.3473
Epoch 795/1000, Loss: 0.3467
Epoch 796/1000, Loss: 0.3462
Epoch 797/1000, Loss: 0.3456
Epoch 798/1000, Loss: 0.3451
Epoch 799/1000, Loss: 0.3445
Epoch 800/1000, Loss: 0.3439
Epoch 801/1000, Loss: 0.3434
Epoch 802/1000, Loss: 0.3429
Epoch 803/1000, Loss: 0.3423
Epoch 804/1000, Loss: 0.3418
Epoch 805/1000, Loss: 0.3412
Epoch 806/1000, Loss: 0.3407
Epoch 807/1000, Loss: 0.3401
Epoch 808/1000, Loss: 0.3396
Epoch 809/1000, Loss: 0.3391
Epoch 810/1000, Loss: 0.3385
Epoch 811/1000, Loss: 0.3380
Epoch 812/1000, Loss: 0.3375
Epoch 813/1000, Loss: 0.3369
Epoch 814/1000, Loss: 0.3364
Epoch 815/1000, Loss: 0.3359
Epoch 816/1000, Loss: 0.3354
Epoch 817/1000, Loss: 0.3348
Epoch 818/1000, Loss: 0.3343
Epoch 819/1000, Loss: 0.3338
Epoch 820/1000, Loss: 0.3333
Epoch 821/1000, Loss: 0.3328
Epoch 822/1000, Loss: 0.3322
Epoch 823/1000, Loss: 0.3317
Epoch 824/1000, Loss: 0.3312
Epoch 825/1000, Loss: 0.3307
Epoch 826/1000, Loss: 0.3302
Epoch 827/1000, Loss: 0.3297
Epoch 828/1000, Loss: 0.3292
Epoch 829/1000, Loss: 0.3287
Epoch 830/1000, Loss: 0.3282
Epoch 831/1000, Loss: 0.3277
Epoch 832/1000, Loss: 0.3272
Epoch 833/1000, Loss: 0.3267
Epoch 834/1000, Loss: 0.3262
Epoch 835/1000, Loss: 0.3257
Epoch 836/1000, Loss: 0.3252
Epoch 837/1000, Loss: 0.3247
Epoch 838/1000, Loss: 0.3242
Epoch 839/1000, Loss: 0.3237
Epoch 840/1000, Loss: 0.3232
Epoch 841/1000, Loss: 0.3227
Epoch 842/1000, Loss: 0.3223
Epoch 843/1000, Loss: 0.3218
Epoch 844/1000, Loss: 0.3213
Epoch 845/1000, Loss: 0.3208
Epoch 846/1000, Loss: 0.3203
Epoch 847/1000, Loss: 0.3199
Epoch 848/1000, Loss: 0.3194
Epoch 849/1000, Loss: 0.3189
Epoch 850/1000, Loss: 0.3184
Epoch 851/1000, Loss: 0.3180
Epoch 852/1000, Loss: 0.3175
Epoch 853/1000, Loss: 0.3170
Epoch 854/1000, Loss: 0.3166
Epoch 855/1000, Loss: 0.3161
Epoch 856/1000, Loss: 0.3156
Epoch 857/1000, Loss: 0.3152
Epoch 858/1000, Loss: 0.3147
Epoch 859/1000, Loss: 0.3142
Epoch 860/1000, Loss: 0.3138
Epoch 861/1000, Loss: 0.3133
Epoch 862/1000, Loss: 0.3129
Epoch 863/1000, Loss: 0.3124
Epoch 864/1000, Loss: 0.3119
Epoch 865/1000, Loss: 0.3115
Epoch 866/1000, Loss: 0.3110
Epoch 867/1000, Loss: 0.3106
Epoch 868/1000, Loss: 0.3101
Epoch 869/1000, Loss: 0.3097
Epoch 870/1000, Loss: 0.3092
Epoch 871/1000, Loss: 0.3088
Epoch 872/1000, Loss: 0.3083
Epoch 873/1000, Loss: 0.3079
Epoch 874/1000, Loss: 0.3075
Epoch 875/1000, Loss: 0.3070
Epoch 876/1000, Loss: 0.3066
Epoch 877/1000, Loss: 0.3061
Epoch 878/1000, Loss: 0.3057
Epoch 879/1000, Loss: 0.3053
Epoch 880/1000, Loss: 0.3048
Epoch 881/1000, Loss: 0.3044
Epoch 882/1000, Loss: 0.3040
Epoch 883/1000, Loss: 0.3035
Epoch 884/1000, Loss: 0.3031
Epoch 885/1000, Loss: 0.3027
Epoch 886/1000, Loss: 0.3022
Epoch 887/1000, Loss: 0.3018
Epoch 888/1000, Loss: 0.3014
Epoch 889/1000, Loss: 0.3010
Epoch 890/1000, Loss: 0.3005
Epoch 891/1000, Loss: 0.3001
Epoch 892/1000, Loss: 0.2997
Epoch 893/1000, Loss: 0.2993
Epoch 894/1000, Loss: 0.2989
Epoch 895/1000, Loss: 0.2984
Epoch 896/1000, Loss: 0.2980
Epoch 897/1000, Loss: 0.2976
Epoch 898/1000, Loss: 0.2972
Epoch 899/1000, Loss: 0.2968
Epoch 900/1000, Loss: 0.2964
Epoch 901/1000, Loss: 0.2960
Epoch 902/1000, Loss: 0.2955
Epoch 903/1000, Loss: 0.2951
Epoch 904/1000, Loss: 0.2947
Epoch 905/1000, Loss: 0.2943
Epoch 906/1000, Loss: 0.2939
Epoch 907/1000, Loss: 0.2935
Epoch 908/1000, Loss: 0.2931
Epoch 909/1000, Loss: 0.2927
Epoch 910/1000, Loss: 0.2923
Epoch 911/1000, Loss: 0.2919
Epoch 912/1000, Loss: 0.2915
Epoch 913/1000, Loss: 0.2911
Epoch 914/1000, Loss: 0.2907
Epoch 915/1000, Loss: 0.2903
Epoch 916/1000, Loss: 0.2899
Epoch 917/1000, Loss: 0.2895
Epoch 918/1000, Loss: 0.2891
Epoch 919/1000, Loss: 0.2887
Epoch 920/1000, Loss: 0.2883
Epoch 921/1000, Loss: 0.2880
Epoch 922/1000, Loss: 0.2876
Epoch 923/1000, Loss: 0.2872
Epoch 924/1000, Loss: 0.2868
Epoch 925/1000, Loss: 0.2864
Epoch 926/1000, Loss: 0.2860
Epoch 927/1000, Loss: 0.2856
Epoch 928/1000, Loss: 0.2853
Epoch 929/1000, Loss: 0.2849
Epoch 930/1000, Loss: 0.2845
Epoch 931/1000, Loss: 0.2841
Epoch 932/1000, Loss: 0.2837
Epoch 933/1000, Loss: 0.2834
Epoch 934/1000, Loss: 0.2830
Epoch 935/1000, Loss: 0.2826
Epoch 936/1000, Loss: 0.2822
Epoch 937/1000, Loss: 0.2819
Epoch 938/1000, Loss: 0.2815
Epoch 939/1000, Loss: 0.2811
Epoch 940/1000, Loss: 0.2807
Epoch 941/1000, Loss: 0.2804
Epoch 942/1000, Loss: 0.2800
Epoch 943/1000, Loss: 0.2796
Epoch 944/1000, Loss: 0.2793
Epoch 945/1000, Loss: 0.2789
Epoch 946/1000, Loss: 0.2785
Epoch 947/1000, Loss: 0.2782
Epoch 948/1000, Loss: 0.2778
Epoch 949/1000, Loss: 0.2775
Epoch 950/1000, Loss: 0.2771
Epoch 951/1000, Loss: 0.2767
Epoch 952/1000, Loss: 0.2764
Epoch 953/1000, Loss: 0.2760
Epoch 954/1000, Loss: 0.2757
Epoch 955/1000, Loss: 0.2753
Epoch 956/1000, Loss: 0.2749
Epoch 957/1000, Loss: 0.2746
Epoch 958/1000, Loss: 0.2742
Epoch 959/1000, Loss: 0.2739
Epoch 960/1000, Loss: 0.2735
Epoch 961/1000, Loss: 0.2732
Epoch 962/1000, Loss: 0.2728
Epoch 963/1000, Loss: 0.2725
Epoch 964/1000, Loss: 0.2721
Epoch 965/1000, Loss: 0.2718
Epoch 966/1000, Loss: 0.2714
Epoch 967/1000, Loss: 0.2711
Epoch 968/1000, Loss: 0.2707
Epoch 969/1000, Loss: 0.2704
Epoch 970/1000, Loss: 0.2700
Epoch 971/1000, Loss: 0.2697
Epoch 972/1000, Loss: 0.2694
Epoch 973/1000, Loss: 0.2690
Epoch 974/1000, Loss: 0.2687
Epoch 975/1000, Loss: 0.2683
Epoch 976/1000, Loss: 0.2680
Epoch 977/1000, Loss: 0.2677
Epoch 978/1000, Loss: 0.2673
Epoch 979/1000, Loss: 0.2670
Epoch 980/1000, Loss: 0.2666
Epoch 981/1000, Loss: 0.2663
Epoch 982/1000, Loss: 0.2660
Epoch 983/1000, Loss: 0.2656
Epoch 984/1000, Loss: 0.2653
Epoch 985/1000, Loss: 0.2650
Epoch 986/1000, Loss: 0.2646
Epoch 987/1000, Loss: 0.2643
Epoch 988/1000, Loss: 0.2640
Epoch 989/1000, Loss: 0.2637
Epoch 990/1000, Loss: 0.2633
Epoch 991/1000, Loss: 0.2630
Epoch 992/1000, Loss: 0.2627
Epoch 993/1000, Loss: 0.2623
Epoch 994/1000, Loss: 0.2620
Epoch 995/1000, Loss: 0.2617
Epoch 996/1000, Loss: 0.2614
Epoch 997/1000, Loss: 0.2610
Epoch 998/1000, Loss: 0.2607
Epoch 999/1000, Loss: 0.2604
Epoch 1000/1000, Loss: 0.2601

MLP on MNIST with Sigmoid
Accuracy: 0.9325, F1-score: 0.9314
[confusion matrix]
[training loss curve]
Test Loss: 0.2811
In [15]:
# Random forest baseline on the flattened pixels
rf = RandomForestClassifier(n_estimators=100, random_state=42)
rf.fit(train_X, train_y)
rf_preds = rf.predict(test_X)
rf_acc = accuracy_score(test_y, rf_preds)
rf_f1 = f1_score(test_y, rf_preds, average='macro')
rf_cm = confusion_matrix(test_y, rf_preds)
summary("Random Forest", rf_acc, rf_f1, rf_cm)

# Multinomial logistic-regression baseline
log_reg = LogisticRegression(max_iter=1000)
log_reg.fit(train_X, train_y)
log_preds = log_reg.predict(test_X)
log_acc = accuracy_score(test_y, log_preds)
log_f1 = f1_score(test_y, log_preds, average='macro')
log_cm = confusion_matrix(test_y, log_preds)
summary("Logistic Regression", log_acc, log_f1, log_cm)
Random Forest
Accuracy: 0.9704, F1-score: 0.9702
[confusion matrix]
Logistic Regression
Accuracy: 0.9256, F1-score: 0.9245
[confusion matrix]
In [16]:
print("MLP untrained with relu")
mlp_untrained_relu = MLP_relu().to(device)

visualize_tsne(mlp_untrained_relu, test_X, test_y, trained=False, device=device)
print("MLP untrained with sigmoid")
mlp_untrained_sigmoid = MLP_sigmoid().to(device)
visualize_tsne(mlp_untrained_sigmoid, test_X, test_y, trained=False, device=device)
print("MLP trained with relu")
visualize_tsne(mlp_relu, test_X, test_y, trained=True, device=device)
print("MLP trained with sigmoid")
visualize_tsne(mlp_sigmoid, test_X, test_y, trained=True, device=device)
MLP untrained with relu
[t-SNE plot]
MLP untrained with sigmoid
[t-SNE plot]
MLP trained with relu
[t-SNE plot]
MLP trained with sigmoid
[t-SNE plot]

t-SNE Comparison¶

t-SNE is a non-linear dimensionality-reduction technique used to map high-dimensional data into a 2D or 3D space for easier visualisation. In the plots for the untrained MLPs (ReLU and Sigmoid), there are no clusters or patterns. After training, however, each of the 10 classes forms a clear cluster, with very few outliers.
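
For reference, a minimal sketch of this kind of visualisation (not the notebook's visualize_tsne, which is defined in an earlier cell): given hidden-layer activations emb of shape (n_samples, n_features) and integer labels y, one could plot

```python
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def tsne_sketch(emb, y):
    # Non-linear projection of the activations to 2D, coloured by class.
    pts = TSNE(n_components=2, random_state=42).fit_transform(emb)
    plt.scatter(pts[:, 0], pts[:, 1], c=y, cmap='tab10', s=5)
    plt.colorbar(label='class')
    plt.show()
```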

In [17]:
(f_train_X, f_train_y), (f_test_X, f_test_y) = fashion_mnist.load_data()
f_test_X = torch.from_numpy(f_test_X.reshape(-1, 784)).float() / 255
f_test_y = torch.from_numpy(f_test_y).to(torch.int64)

preds_f, acc_f, f1_f, cm_f, test_loss_f = evaluate_model(mlp_relu, f_test_X, f_test_y, device)
summary("MLP on Fashion-MNIST", acc_f, f1_f, cm_f)
print(f"Test Loss: {test_loss_f:.4f}")

visualize_tsne(mlp_relu, f_test_X, f_test_y, trained=True, device=device)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
29515/29515 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26421880/26421880 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
5148/5148 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4422102/4422102 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
/tmp/ipykernel_6110/160832730.py:2: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:203.)
  f_test_X = torch.from_numpy(f_test_X.reshape(-1, 784)).float() / 255
MLP on Fashion-MNIST
Accuracy: 0.0662, F1-score: 0.0457
[confusion matrix]
Test Loss: 30.0740
[t-SNE plot]
In [18]:
print("\nTesting MLP (Sigmoid) on Fashion-MNIST...")

(f_train_X, f_train_y), (f_test_X, f_test_y) = fashion_mnist.load_data()
f_test_X = torch.from_numpy(f_test_X.reshape(-1, 784)).float() / 255
f_test_y = torch.from_numpy(f_test_y).to(torch.int64)

preds_f, acc_f, f1_f, cm_f, test_loss_f = evaluate_model(mlp_sigmoid, f_test_X, f_test_y, device)
summary("MLP on Fashion-MNIST", acc_f, f1_f, cm_f)
print(f"Test Loss: {test_loss_f:.4f}")

visualize_tsne(mlp_sigmoid, f_test_X, f_test_y, trained=True, device=device)
Testing MLP (Sigmoid) on Fashion-MNIST...

MLP on Fashion-MNIST
Accuracy: 0.1123, F1-score: 0.0623
[confusion matrix]
Test Loss: 4.1060
[t-SNE plot]

Testing Trained MLP on Fashion MNIST¶

The Fashion-MNIST dataset is similar to MNIST in that it also has 28 x 28 grayscale images and 10 output classes. When we apply our MNIST-trained MLPs to Fashion-MNIST, accuracy is very poor: about 6.6% with ReLU as the activation and about 11.2% with Sigmoid. We could get better results by freezing the current embeddings, adding 1 or 2 more layers, and training those for a few epochs (a sketch follows below). In the t-SNE plot there is no visible clustering, which suggests that no features useful for this dataset have been learnt. Moreover, both MLPs appear to predict class 2 far more often than any other class.
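
A minimal sketch of that idea, assuming the penultimate layer of mlp_relu has 20 units (as the t-SNE discussion below suggests) and a hypothetical embed() helper that returns those activations:

```python
import torch.nn as nn
import torch.optim as optim

# Freeze the MNIST-trained backbone so only the new head learns.
for p in mlp_relu.parameters():
    p.requires_grad = False

# Hypothetical head on the frozen 20-unit embedding; it would be trained on
# Fashion-MNIST for a few epochs via head(mlp_relu.embed(x)).
head = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 10)).to(device)
opt_head = optim.Adam(head.parameters(), lr=1e-3)
```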

Summary of Results for Question 1¶

| Model | Activation Function | Accuracy | F1 Score |
|---|---|---|---|
| MLP (ReLU) | ReLU | 0.9621 | 0.9618 |
| MLP (Sigmoid) | Sigmoid | 0.9325 | 0.9314 |
| Random Forest | N/A | 0.9704 | 0.9702 |
| Logistic Regression | Sigmoid | 0.9256 | 0.9245 |

Fashion MNIST¶

| Model | Activation | Accuracy |
|---|---|---|
| MLP (ReLU) | ReLU | 0.0662 |
| MLP (Sigmoid) | Sigmoid | 0.1123 |

We observe that the best-performing model is the random forest, closely followed by the MLP with ReLU and then the MLP with Sigmoid, with logistic regression performing the worst. The MLPs' shortfall can be explained by their small capacity: only 2 hidden layers with few neurons. Logistic regression performs significantly worse than the others because it can only learn a linear decision boundary. Moreover, none of these models exploits the spatial structure of images; each pixel is treated as an individual feature. We should therefore be able to achieve higher accuracy with CNNs, which can capture local features such as corners and edges. We also observe many misclassifications, particularly between 2 and 7, which can be attributed to the fact that these models have no notion of the spatial ordering of pixels.

Comparison of t-SNE plots of MNIST and Fashion MNIST¶

In the t-SNE plot of the 20-neuron layer for the Fashion-MNIST dataset, no proper clustering is visible, which suggests that the required features haven't been learnt and the model is unable to distinguish between the classes. In contrast, there is clear, visible clustering in the t-SNE plot for the same model evaluated on MNIST, the dataset it was trained on. We conclude that the feature representations learnt on MNIST do not transfer well to Fashion-MNIST. As noted above, we could achieve higher accuracy by freezing the current model, adding 1-2 more layers, and training them for a few epochs on Fashion-MNIST.

In [64]:
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3)   # 28x28 -> 26x26 (3x3 kernel, no padding)
        self.pool = nn.MaxPool2d(2, 2)     # 26x26 -> 13x13
        self.fc1 = nn.Linear(32*13*13, 128)
        self.fc2 = nn.Linear(128, 10)
    def forward(self, x):
        x = x.view(-1, 1, 28, 28)          # flat 784-vector -> NCHW image
        x = self.pool(F.relu(self.conv1(x)))
        x = x.view(-1, 32*13*13)           # flatten for the fully connected head
        x = F.relu(self.fc1(x))
        return self.fc2(x)
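
As a quick sanity check on the 32*13*13 flatten size (a 3x3 convolution with no padding maps 28x28 to 26x26, and 2x2 max-pooling halves that to 13x13):

```python
import torch

cnn_check = CNN()
x = torch.zeros(1, 1, 28, 28)
print(cnn_check.pool(cnn_check.conv1(x)).shape)  # torch.Size([1, 32, 13, 13])
```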
In [65]:
cnn = CNN().to(device)
opt = optim.Adam(cnn.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()
In [66]:
from torch.utils.data import TensorDataset, DataLoader

batch_size = 64
train_dataset = TensorDataset(train_X, train_y)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

train_losses = []
for epoch in range(100):
    cnn.train()
    total_loss = 0
    for X_batch, y_batch in train_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        opt.zero_grad()
        out = cnn(X_batch)
        loss = loss_fn(out, y_batch)
        loss.backward()
        opt.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(train_loader)
    train_losses.append(avg_loss)
    print(f"Epoch {epoch+1}/100, Loss: {avg_loss:.4f}")
Epoch 1/100, Loss: 0.2182
Epoch 2/100, Loss: 0.0685
Epoch 3/100, Loss: 0.0466
Epoch 4/100, Loss: 0.0336
Epoch 5/100, Loss: 0.0247
...
Epoch 10/100, Loss: 0.0065
...
Epoch 34/100, Loss: 0.0000
(per-epoch log truncated; from around epoch 34 the training loss sits near 0.0000, with occasional small spikes)
Epoch 98/100, Loss: 0.0010
Epoch 99/100, Loss: 0.0010
Epoch 100/100, Loss: 0.0023
In [67]:
preds, acc_cnn, f1_cnn, cm_cnn, test_loss_cnn = evaluate_model(cnn, test_X, test_y, device)
params_cnn = sum(p.numel() for p in cnn.parameters())

start = time.time()
with torch.no_grad():
    _ = cnn(test_X[:512].to(device))
t_cnn = time.time() - start

summary("Simple CNN", acc_cnn, f1_cnn, cm_cnn, train_losses)
print(f"Params: {params_cnn:,}, Inference time: {t_cnn:.4f}s, Test Loss: {test_loss_cnn:.4f}")
Simple CNN
Accuracy: 0.9851, F1-score: 0.9850
[confusion matrix]
[training loss curve]
Params: 693,962, Inference time: 0.0006s, Test Loss: 0.1167
In [8]:
batch_size = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def preprocess_batch(Xb):
    Xb = Xb.to(device)
    # Replicate the single grayscale channel to 3 channels, then upsample to
    # the 224x224 input size the ImageNet-pretrained backbones expect.
    Xb = Xb.repeat(1, 3, 1, 1)
    Xb = F.interpolate(Xb, size=(224, 224), mode='bilinear', align_corners=False)
    # Normalise with the standard ImageNet per-channel means and stds.
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
    Xb = (Xb - mean) / std
    return Xb

def evaluate_model_in_batches(model, X, y, batch_size, device):
    model.eval()  
    preds = []
    true_labels = []
    total_loss = 0
    criterion = nn.CrossEntropyLoss()

    # Create batches
    for i in range(0, len(X), batch_size):
        Xb = X[i:i + batch_size].to(device)  
        yb = y[i:i + batch_size].to(device)
        Xb = preprocess_batch(Xb)
        with torch.no_grad():
            outputs = model(Xb)
            loss = criterion(outputs, yb)
            total_loss += loss.item() * len(Xb)
        preds.append(outputs.argmax(dim=1).cpu().numpy())  
        true_labels.append(yb.cpu().numpy())
    preds = np.concatenate(preds)
    true_labels = np.concatenate(true_labels)
    acc = (preds == true_labels).mean()
    f1 = f1_score(true_labels, preds, average='weighted')
    cm = confusion_matrix(true_labels, preds)
    
    avg_loss = total_loss / len(X)
    
    return preds, acc, f1, cm, avg_loss
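
For illustration, a hypothetical quick check (reusing test_X from the earlier MNIST cells): preprocess_batch maps a flat grayscale batch to a normalized 3-channel 224x224 batch.

```python
Xb = test_X[:8].reshape(-1, 1, 28, 28)
print(preprocess_batch(Xb).shape)  # expected: torch.Size([8, 3, 224, 224])
```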

Pretrained CNNs without fine-tuning¶

In [14]:
mobilenet = models.mobilenet_v2(weights="IMAGENET1K_V1")
mobilenet.classifier[1] = nn.Linear(1280, 10)
mobilenet = mobilenet.to(device)

efficientnet = models.efficientnet_b0(weights="IMAGENET1K_V1")
efficientnet.classifier[1] = nn.Linear(1280, 10)
efficientnet = efficientnet.to(device)
print("Evaluating MobileNetV2...")
preds_mob, acc_mob, f1_mob, cm_mob, loss_mob = evaluate_model_in_batches(
    mobilenet, test_X.reshape(-1,1,28,28), test_y, batch_size, device
)
summary("MobileNetV2", acc_mob, f1_mob, cm_mob)

params_mob = sum(p.numel() for p in mobilenet.parameters())
start = time.time()
with torch.no_grad():
    _ = mobilenet(preprocess_batch(test_X[:256].reshape(-1,1,28,28).to(device)))
t_mob = time.time() - start
print(f"Params: {params_mob:,}, Inference time: {t_mob:.4f}s, Test Loss: {loss_mob:.4f}")
print("\nEvaluating EfficientNet-B0.")
preds_eff, acc_eff, f1_eff, cm_eff, loss_eff = evaluate_model_in_batches(
    efficientnet, test_X.reshape(-1,1,28,28), test_y, batch_size, device
)
summary("EfficientNet-B0", acc_eff, f1_eff, cm_eff)

params_eff = sum(p.numel() for p in efficientnet.parameters())
start = time.time()
with torch.no_grad():
    _ = efficientnet(preprocess_batch(test_X[:256].reshape(-1,1,28,28).to(device)))
t_eff = time.time() - start
print(f"Params: {params_eff:,}, Inference time: {t_eff:.4f}s, Test Loss: {loss_eff:.4f}")
Evaluating MobileNetV2...

MobileNetV2
Accuracy: 0.0942, F1-score: 0.0218
[confusion matrix]
Params: 2,236,682, Inference time: 0.0057s, Test Loss: 2.3812

Evaluating EfficientNet-B0.

EfficientNet-B0
Accuracy: 0.1007, F1-score: 0.0776
[confusion matrix]
Params: 4,020,358, Inference time: 0.0082s, Test Loss: 2.2830

Fine-Tuning the Pretrained CNNs¶

In [68]:
mobilenet = models.mobilenet_v2(weights="IMAGENET1K_V1")
# Freeze the pretrained backbone; only the new classifier head will be trained.
for p in mobilenet.parameters():
    p.requires_grad = False
in_features = mobilenet.classifier[-1].in_features
mobilenet.classifier = nn.Sequential(
    nn.Linear(in_features, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)
mobilenet = mobilenet.to(device)

opt_m = optim.Adam(filter(lambda p: p.requires_grad, mobilenet.parameters()), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

X_train = train_X.reshape(-1,1,28,28)
y_train = train_y
train_losses_m = []

print("\nTraining MobileNetV2 (fine-tuned with 2 layers)...")
for epoch in range(10):
    mobilenet.train()
    total_loss = 0
    for i in range(0, len(X_train), batch_size):
        Xb = X_train[i:i+batch_size].to(device)
        yb = y_train[i:i+batch_size].to(device)
        Xb = preprocess_batch(Xb)
        opt_m.zero_grad()
        out = mobilenet(Xb)
        loss = loss_fn(out, yb)
        loss.backward()
        opt_m.step()
        total_loss += loss.item() * len(Xb)
    avg_loss = total_loss / len(X_train)
    train_losses_m.append(avg_loss)
    print(f"Epoch {epoch+1}/10 Loss: {avg_loss:.4f}")
preds_m, acc_m, f1_m, cm_m, test_loss_m = evaluate_model_in_batches(mobilenet, test_X.reshape(-1,1,28,28), test_y, batch_size, device)
summary("MobileNetV2 (fine-tuned)", acc_m, f1_m, cm_m, train_losses=train_losses_m)
params_m = sum(p.numel() for p in mobilenet.parameters())
start = time.time()
with torch.no_grad():
    _ = mobilenet(preprocess_batch(test_X[:256].reshape(-1,1,28,28).to(device)))
t_m = time.time() - start
print(f"Params: {params_m:,}, Inference time: {t_m:.4f}s, Test Loss: {test_loss_m:.4f}")
efficient = models.efficientnet_b0(weights="IMAGENET1K_V1")
# Same recipe for EfficientNet-B0: freeze the backbone, replace the classifier.
for p in efficient.parameters():
    p.requires_grad = False
last_layer = efficient.classifier[-1]
if isinstance(last_layer, nn.Linear):
    in_features = last_layer.in_features
else:
    in_features = last_layer[1].in_features
efficient.classifier = nn.Sequential(
    nn.Linear(in_features, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)
efficient = efficient.to(device)

opt_e = optim.Adam(filter(lambda p: p.requires_grad, efficient.parameters()), lr=1e-3)
train_losses_e = []
for epoch in range(10):
    efficient.train()
    total_loss = 0
    for i in range(0, len(X_train), batch_size):
        Xb = X_train[i:i+batch_size].to(device)
        yb = y_train[i:i+batch_size].to(device)
        Xb = preprocess_batch(Xb)
        opt_e.zero_grad()
        out = efficient(Xb)
        loss = loss_fn(out, yb)
        loss.backward()
        opt_e.step()
        total_loss += loss.item() * len(Xb)
    avg_loss = total_loss / len(X_train)
    train_losses_e.append(avg_loss)
    print(f"Epoch {epoch+1}/10 Loss: {avg_loss:.4f}")
preds_e, acc_e, f1_e, cm_e, test_loss_e = evaluate_model_in_batches(efficient, test_X.reshape(-1,1,28,28), test_y, batch_size, device)
summary("EfficientNet_B0 (fine-tuned)", acc_e, f1_e, cm_e, train_losses=train_losses_e)
params_e = sum(p.numel() for p in efficient.parameters())
start = time.time()
with torch.no_grad():
    _ = efficient(preprocess_batch(test_X[:256].reshape(-1,1,28,28).to(device)))
t_e = time.time() - start
print(f"Params: {params_e:,}, Inference time: {t_e:.4f}s, Test Loss: {test_loss_e:.4f}")
Training MobileNetV2 (fine-tuned with 2 layers)...
Epoch 1/10 Loss: 0.3012
Epoch 2/10 Loss: 0.1365
Epoch 3/10 Loss: 0.1097
Epoch 4/10 Loss: 0.0922
Epoch 5/10 Loss: 0.0826
Epoch 6/10 Loss: 0.0742
Epoch 7/10 Loss: 0.0709
Epoch 8/10 Loss: 0.0650
Epoch 9/10 Loss: 0.0561
Epoch 10/10 Loss: 0.0474

MobileNetV2 (fine-tuned)
Accuracy: 0.9721, F1-score: 0.9721
[confusion matrix]
[training loss curve]
Params: 2,585,994, Inference time: 0.0059s, Test Loss: 0.1047
Epoch 1/10 Loss: 0.3345
Epoch 2/10 Loss: 0.1747
Epoch 3/10 Loss: 0.1461
Epoch 4/10 Loss: 0.1220
Epoch 5/10 Loss: 0.1081
Epoch 6/10 Loss: 0.0986
Epoch 7/10 Loss: 0.0910
Epoch 8/10 Loss: 0.0814
Epoch 9/10 Loss: 0.0784
Epoch 10/10 Loss: 0.0728

EfficientNet_B0 (fine-tuned)
Accuracy: 0.9730, F1-score: 0.9730
[confusion matrix]
[training loss curve]
Params: 4,369,670, Inference time: 0.0084s, Test Loss: 0.0921

Summary of Question 2¶

| Model | Accuracy | F1 Score | Inference Time (s / batch) | Parameters |
|---|---|---|---|---|
| CNN | 0.9851 | 0.9850 | 0.0006 | 693,962 |
| MobileNetV2 (fine-tuned) | 0.9721 | 0.9721 | 0.0059 | 2,585,994 |
| EfficientNet_B0 (fine-tuned) | 0.9730 | 0.9730 | 0.0084 | 4,369,670 |
| MobileNetV2 (pretrained) | 0.0942 | 0.0218 | 0.0057 | 2,236,682 |
| EfficientNet_B0 (pretrained) | 0.1007 | 0.0776 | 0.0082 | 4,020,358 |

(Note: per the timing code above, the CNN was timed on a 512-image batch and the other models on 256 images, so the times are only roughly comparable.)

We observe that the simple CNN achieved the highest accuracy and F1-score of all the models, while also having the fewest parameters and the lowest inference time. Using the pretrained models without any fine-tuning gave accuracy no better than random guessing (about 10% across 10 classes), which the confusion matrices confirm. This is expected for two reasons: the backbones were trained on ImageNet, which primarily contains RGB photographs of objects and animals, so their features do not directly transfer to handwritten-digit recognition; and the replacement 10-way classifier layer is randomly initialised and never trained in this setting.

We then performed transfer learning by removing the last layer and replacing it with 2 hidden layers and a 10-neuron output layer, fine-tuning the new head on the training set for 10 epochs. We used MobileNetV2 and EfficientNet_B0. Although these are powerful pretrained models with a large number of parameters, they still do not perform as well on MNIST as the simple CNN, because they were trained to recognise features in complex RGB images, and those features do not map cleanly onto the simple, limited features of MNIST digits. A small CNN can learn these features from scratch very easily, giving better accuracy with far fewer parameters.

Overall, after fine-tuning, the pretrained models still achieve higher accuracy than the other models such as the MLP, random forest, and logistic regression. We could likely push accuracy higher with more fine-tuning and more training epochs. We also observe that inference time roughly scales with the number of parameters: for example, EfficientNet_B0 has about 1.7x the parameters of MobileNetV2 and takes about 1.4x as long per batch.